Skip to content

Conversation

@jpienaar
Copy link
Member

Enables adding instrumentation to pass manager that can track/flag invariants. This would be useful for cases where one some tighter requirements than the general dialects or for a phase of conversion that elsewhere.

It would enable making verify also just a regular instrumentation I believe, but also a non-goal as that is a first class concept and baseline for the ops and passes.

Would have enabled some of the requirements of https://discourse.llvm.org/t/pre-verification-logic-before-running-conversion-pass-in-mlir/88318/10 .

@jpienaar jpienaar requested review from ftynse and joker-eph October 13, 2025 05:05
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Oct 13, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 13, 2025

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Jacques Pienaar (jpienaar)

Changes

Enables adding instrumentation to pass manager that can track/flag invariants. This would be useful for cases where one some tighter requirements than the general dialects or for a phase of conversion that elsewhere.

It would enable making verify also just a regular instrumentation I believe, but also a non-goal as that is a first class concept and baseline for the ops and passes.

Would have enabled some of the requirements of https://discourse.llvm.org/t/pre-verification-logic-before-running-conversion-pass-in-mlir/88318/10 .


Full diff: https://github.com/llvm/llvm-project/pull/163126.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Pass/Pass.h (+4)
  • (modified) mlir/include/mlir/Pass/PassInstrumentation.h (+2)
  • (modified) mlir/lib/Pass/Pass.cpp (+20-13)
  • (modified) mlir/unittests/Pass/PassManagerTest.cpp (+98)
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 16893c6db87b1..f0b0979a81ee3 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -17,6 +17,7 @@
 #include <optional>
 
 namespace mlir {
+class PassInstrumentation;
 namespace detail {
 class OpToOpPassAdaptor;
 struct OpPassManagerImpl;
@@ -334,6 +335,9 @@ class Pass {
 
   /// Allow access to 'passOptions'.
   friend class PassInfo;
+
+  /// Allow access to 'signalPassFailure'.
+  friend class PassInstrumentation;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h
index 917bac4b22288..25a8e77be75ee 100644
--- a/mlir/include/mlir/Pass/PassInstrumentation.h
+++ b/mlir/include/mlir/Pass/PassInstrumentation.h
@@ -80,6 +80,8 @@ class PassInstrumentation {
   /// name of the analysis that was computed, its TypeID, as well as the
   /// current operation being analyzed.
   virtual void runAfterAnalysis(StringRef name, TypeID id, Operation *op) {}
+
+  static void signalPassFailure(Pass *pass);
 };
 
 /// This class holds a collection of PassInstrumentation objects, and invokes
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 521c7c6be17b6..17ac475b42f4b 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -599,17 +599,20 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
   if (pi)
     pi->runBeforePass(pass, op);
 
-  bool passFailed = false;
-  op->getContext()->executeAction<PassExecutionAction>(
-      [&]() {
-        // Invoke the virtual runOnOperation method.
-        if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
-          adaptor->runOnOperation(verifyPasses);
-        else
-          pass->runOnOperation();
-        passFailed = pass->passState->irAndPassFailed.getInt();
-      },
-      {op}, *pass);
+  bool passFailed = pass->passState->irAndPassFailed.getInt();
+  if (!passFailed) {
+    op->getContext()->executeAction<PassExecutionAction>(
+        [&]() {
+          // Invoke the virtual runOnOperation method.
+          if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
+            adaptor->runOnOperation(verifyPasses);
+          else
+            pass->runOnOperation();
+          passFailed = pass->passState->irAndPassFailed.getInt();
+        },
+        {op}, *pass);
+  }
+
 
   // Invalidate any non preserved analyses.
   am.invalidate(pass->passState->preservedAnalyses);
@@ -640,10 +643,12 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
 
   // Instrument after the pass has run.
   if (pi) {
-    if (passFailed)
+    if (passFailed) {
       pi->runAfterPassFailed(pass, op);
-    else
+    } else {
       pi->runAfterPass(pass, op);
+      passFailed = passFailed || pass->passState->irAndPassFailed.getInt();
+    }
   }
 
   // Return if the pass signaled a failure.
@@ -1198,6 +1203,8 @@ void PassInstrumentation::runBeforePipeline(
 void PassInstrumentation::runAfterPipeline(
     std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {}
 
+void PassInstrumentation::signalPassFailure(Pass *pass) { pass->signalPassFailure(); }
+
 //===----------------------------------------------------------------------===//
 // PassInstrumentor
 //===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 7e618811eabf4..86c793384db11 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -14,6 +14,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassInstrumentation.h"
 #include "gtest/gtest.h"
 
 #include <memory>
@@ -117,6 +118,103 @@ struct AddSecondAttrFunctionPass
   }
 };
 
+/// PassInstrumentation to count pass callbacks and signal pass failures.
+struct TestPassInstrumentation : public PassInstrumentation {
+  int beforePassCallbackCount = 0;
+  int afterPassCallbackCount = 0;
+  int afterPassFailedCallbackCount = 0;
+
+  bool failBeforePass = false;
+  bool failAfterPass = false;
+
+  void runBeforePass(Pass *pass, Operation *op) override {
+    if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;
+
+    ++beforePassCallbackCount;
+    if (failBeforePass)
+      signalPassFailure(pass);
+  }
+  void runAfterPass(Pass *pass, Operation *op) override {
+    if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;
+
+    ++afterPassCallbackCount;
+    if (failAfterPass)
+      signalPassFailure(pass);
+  }
+  void runAfterPassFailed(Pass *pass, Operation *op) override {
+    if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;
+
+    ++afterPassFailedCallbackCount;
+  }
+};
+
+TEST(PassManagerTest, PassInstrumentation) {
+  MLIRContext context;
+  context.loadDialect<func::FuncDialect>();
+  Builder b(&context);
+
+  // Create a module with 1 function.
+  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
+  auto func = func::FuncOp::create(b.getUnknownLoc(), "test_func",
+                                   b.getFunctionType({}, {}));
+  func.setPrivate();
+  module->push_back(func);
+
+  struct InstrumentationCounts {
+    int beforePass;
+    int afterPass;
+    int afterPassFailed;
+  };
+
+  auto runInstrumentation =
+      [&](bool failBefore,
+          bool failAfter) -> std::pair<LogicalResult, InstrumentationCounts> {
+    // Instantiate and run our pass.
+    auto pm = PassManager::on<ModuleOp>(&context);
+    auto instrumentation = std::make_unique<TestPassInstrumentation>();
+    auto *instrumentationPtr = instrumentation.get();
+    instrumentation->failBeforePass = failBefore;
+    instrumentation->failAfterPass = failAfter;
+    pm.addInstrumentation(std::move(instrumentation));
+    pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
+    LogicalResult result = pm.run(module.get());
+
+    InstrumentationCounts counts = {
+        .beforePass = instrumentationPtr->beforePassCallbackCount,
+        .afterPass = instrumentationPtr->afterPassCallbackCount,
+        .afterPassFailed = instrumentationPtr->afterPassFailedCallbackCount};
+    return {result, counts};
+  };
+
+  for (bool failBefore : {false, true}) {
+    for (bool failAfter : {false, true}) {
+      auto [result, counts] = runInstrumentation(failBefore, failAfter);
+
+      InstrumentationCounts expected;
+      if (failBefore) {
+        EXPECT_TRUE(failed(result))
+            << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+        expected = {.beforePass = 1, .afterPass = 0, .afterPassFailed = 1};
+      } else if (failAfter) {
+        EXPECT_TRUE(failed(result))
+            << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+        expected = {.beforePass = 1, .afterPass = 1, .afterPassFailed = 0};
+      } else {
+        EXPECT_TRUE(succeeded(result))
+            << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+        expected = {.beforePass = 1, .afterPass = 1, .afterPassFailed = 0};
+      }
+
+      EXPECT_EQ(counts.beforePass, expected.beforePass)
+          << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+      EXPECT_EQ(counts.afterPass, expected.afterPass)
+          << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+      EXPECT_EQ(counts.afterPassFailed, expected.afterPassFailed)
+          << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+    }
+  }
+}
+
 TEST(PassManagerTest, ExecutionAction) {
   MLIRContext context;
   context.loadDialect<func::FuncDialect>();

@github-actions
Copy link

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff origin/main HEAD --extensions cpp,h -- mlir/include/mlir/Pass/Pass.h mlir/include/mlir/Pass/PassInstrumentation.h mlir/lib/Pass/Pass.cpp mlir/unittests/Pass/PassManagerTest.cpp --diff_from_common_commit

⚠️
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing origin/main to the base branch/commit you want to compare against.
⚠️

View the diff from clang-format here.
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 17ac475b4..9dc947a78 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -613,7 +613,6 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
         {op}, *pass);
   }
 
-
   // Invalidate any non preserved analyses.
   am.invalidate(pass->passState->preservedAnalyses);
 
@@ -1203,7 +1202,9 @@ void PassInstrumentation::runBeforePipeline(
 void PassInstrumentation::runAfterPipeline(
     std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {}
 
-void PassInstrumentation::signalPassFailure(Pass *pass) { pass->signalPassFailure(); }
+void PassInstrumentation::signalPassFailure(Pass *pass) {
+  pass->signalPassFailure();
+}
 
 //===----------------------------------------------------------------------===//
 // PassInstrumentor
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 86c793384..50cd8ee1c 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -128,21 +128,24 @@ struct TestPassInstrumentation : public PassInstrumentation {
   bool failAfterPass = false;
 
   void runBeforePass(Pass *pass, Operation *op) override {
-    if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;
+    if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
+      return;
 
     ++beforePassCallbackCount;
     if (failBeforePass)
       signalPassFailure(pass);
   }
   void runAfterPass(Pass *pass, Operation *op) override {
-    if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;
+    if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
+      return;
 
     ++afterPassCallbackCount;
     if (failAfterPass)
       signalPassFailure(pass);
   }
   void runAfterPassFailed(Pass *pass, Operation *op) override {
-    if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;
+    if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
+      return;
 
     ++afterPassFailedCallbackCount;
   }

/// current operation being analyzed.
virtual void runAfterAnalysis(StringRef name, TypeID id, Operation *op) {}

static void signalPassFailure(Pass *pass);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add documentation for this method, also should this be a static method or an instance method? I know it doesn't have to be an instance method but that would help keep the scope of API exposure slimmer (otherwise, should we just make signalPassFailure public?)

passFailed = pass->passState->irAndPassFailed.getInt();
},
{op}, *pass);
bool passFailed = pass->passState->irAndPassFailed.getInt();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is non-intuitive to me, should be documented.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants